Conversation

@divyashreepathihalli (Collaborator) commented on Aug 11, 2025

This PR adds a Knowledge Distillation API to Keras.

Key Features

Core Components

  • Distiller: Main distillation model that combines teacher and student models
  • Strategies: Pluggable distillation strategies (LogitsDistillation, FeatureDistillation, MultiOutputDistillation)

Usage Examples

Basic Knowledge Distillation

import keras
from keras.distillation import Distiller, LogitsDistillation

# Create models
teacher = keras.Sequential([...])  # Large, pre-trained model
student = keras.Sequential([...])  # Smaller model to train

# Set up distillation
distiller = Distiller(
    teacher=teacher,
    student=student,
    strategies=[LogitsDistillation(temperature=3.0)],
    alpha=...
)

# Standard Keras workflow
distiller.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
distiller.fit(x_train, y_train, epochs=10)
predictions = distiller.predict(x_test)
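
The same Distiller accepts several strategies at once. The sketch below combines logits and feature distillation; the FeatureDistillation constructor arguments (teacher_layer_name, student_layer_name), the layer names, and the alpha value are illustrative assumptions inferred from attributes discussed later in this thread and may not match the final API.

import keras
from keras.distillation import Distiller, FeatureDistillation, LogitsDistillation

# Small models with named hidden layers (hypothetical names).
teacher = keras.Sequential([
    keras.Input(shape=(32,)),
    keras.layers.Dense(64, activation="relu", name="teacher_hidden"),
    keras.layers.Dense(10),
])
student = keras.Sequential([
    keras.Input(shape=(32,)),
    keras.layers.Dense(64, activation="relu", name="student_hidden"),
    keras.layers.Dense(10),
])

distiller = Distiller(
    teacher=teacher,
    student=student,
    strategies=[
        LogitsDistillation(temperature=3.0),
        FeatureDistillation(
            teacher_layer_name="teacher_hidden",
            student_layer_name="student_hidden",
        ),
    ],
    alpha=0.5,  # illustrative value
)

The compile/fit/predict workflow from the snippet above applies unchanged.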

@gemini-code-assist (bot, Contributor) left a comment

Summary of Changes

Hello @divyashreepathihalli, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new Knowledge Distillation API to Keras, designed to facilitate the efficient transfer of learned knowledge from larger, pre-trained "teacher" models to smaller "student" models. The API seamlessly integrates with Keras's existing training, evaluation, and prediction workflows, providing a flexible and extensible framework for various distillation techniques.

Highlights

  • New Distiller Model: A core Distiller class is added, which is a keras.Model subclass, enabling the combination and training of teacher and student models within the standard Keras workflow.
  • Pluggable Distillation Strategies: Introduces a BaseDistillationStrategy and three concrete implementations: LogitsDistillation (for softening logits), FeatureDistillation (for intermediate feature matching), and MultiOutputDistillation (for handling models with multiple outputs).
  • Configurable Loss Balancing: The Distiller allows specifying an alpha parameter to balance the contribution of the student's original loss and the distillation loss (one common weighting convention is sketched after this list).
  • Automatic Teacher Freezing: The teacher model is automatically set to non-trainable (trainable=False) during the distillation process to prevent its weights from being updated.
  • Comprehensive Testing: New test files (distiller_test.py and strategies_test.py) are added to ensure the robustness and correctness of the new API, covering initialization, loss computation, and end-to-end workflows.
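
As a hedged illustration of the alpha balancing described above (the exact weighting convention used by Distiller may differ), one common formulation is:

def combined_loss(student_loss, distillation_loss, alpha):
    # Common convention: alpha weights the student's own task loss and
    # (1 - alpha) weights the combined distillation loss from the strategies.
    return alpha * student_loss + (1.0 - alpha) * distillation_loss
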
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Supported commands:

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in issue comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces a comprehensive Knowledge Distillation API to Keras, which is a great addition. The implementation is well-structured with a Distiller model and pluggable strategies. My review has identified a few issues: a critical issue with the FeatureDistillation strategy which is not fully implemented, a high-severity issue with an unused temperature parameter in the Distiller class that could mislead users, and a medium-severity issue regarding a simplistic fallback for loss calculation in multi-output scenarios. Addressing these points will improve the robustness and clarity of this new API.

Comment on lines 229 to 236
if isinstance(y_pred, (list, tuple)):
    # For multi-output models, use the first output for student loss.
    # This is a simplified approach for compatibility.
    if isinstance(y, (list, tuple)):
        student_loss = self.student_loss_fn(y[0], y_pred[0])
    else:
        student_loss = self.student_loss_fn(y, y_pred[0])
Contributor

Severity: medium

The fallback logic for calculating the student loss in _compute_loss for multi-output models is overly simplistic as it always defaults to using the first output (y_pred[0]). This might not align with user expectations for all multi-output scenarios and could lead to incorrect training behavior if model.compile() is not called with a loss that properly handles multiple outputs.

While the primary path using self.compiled_loss is correct, this fallback could be made more robust. Consider raising a more specific error if a multi-output model is used without a compiled loss, or clarifying this behavior more explicitly in the documentation.
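
A hedged sketch of that suggestion (hypothetical helper, not the PR's code), which fails loudly for multi-output students instead of silently using the first output:

def fallback_student_loss(y, y_pred, student_loss_fn, compiled_loss=None):
    # Hypothetical fallback: multi-output predictions require a compiled loss
    # that knows how to map losses to outputs; otherwise raise a clear error.
    if isinstance(y_pred, (list, tuple)):
        if compiled_loss is None:
            raise ValueError(
                "The student model has multiple outputs but no compiled loss "
                "was provided. Call compile() with a loss (or a dict/list of "
                "losses) covering every output."
            )
        return compiled_loss(y, y_pred)
    return student_loss_fn(y, y_pred)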

@divyashreepathihalli marked this pull request as draft on August 11, 2025, 23:44
@divyashreepathihalli (Collaborator, Author)

/gemini review

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces a well-structured Knowledge Distillation API to Keras, a valuable addition for model compression and knowledge transfer. The implementation of the Distiller model and pluggable strategies is clean and the code is generally of high quality, with good documentation and examples.

I have identified a few significant issues that should be addressed to improve the robustness and completeness of the API:

  1. A critical flaw in FeatureDistillation limits its use to sequential models, which will prevent it from working with more complex architectures.
  2. The Distiller's loss computation for multi-output models has a bug in its fallback logic.
  3. The Distiller model cannot be correctly serialized and deserialized due to an incomplete get_config implementation (a sketch of the standard serialization pattern follows below).
  4. The tests for FeatureDistillation are not comprehensive enough to cover non-sequential model architectures.

Addressing these points will significantly enhance the reliability and usability of this new API. Overall, this is a great contribution.
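
On point 3, a minimal sketch of the standard Keras serialization pattern for a model wrapping sub-models (illustrative class and fields, not the PR's actual Distiller code):

import keras

@keras.saving.register_keras_serializable(package="distillation_sketch")
class DistillerSketch(keras.Model):
    def __init__(self, teacher, student, alpha=0.5, **kwargs):
        super().__init__(**kwargs)
        self.teacher = teacher
        self.student = student
        self.alpha = alpha

    def call(self, inputs, training=False):
        # Inference runs the student only.
        return self.student(inputs, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({
            # Sub-models must be serialized explicitly for save/load to work.
            "teacher": keras.saving.serialize_keras_object(self.teacher),
            "student": keras.saving.serialize_keras_object(self.student),
            "alpha": self.alpha,
        })
        return config

    @classmethod
    def from_config(cls, config):
        config["teacher"] = keras.saving.deserialize_keras_object(config["teacher"])
        config["student"] = keras.saving.deserialize_keras_object(config["student"])
        return cls(**config)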

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces a comprehensive Knowledge Distillation API to Keras, which is a great addition. The implementation is well-structured with a Distiller model and pluggable DistillationStrategy classes. The code is generally clean and includes a good set of tests.

I've identified a few key areas for improvement:

  1. Serialization: The Distiller and MultiOutputDistillation classes are not correctly configured for serialization, which will prevent models using them from being saved and loaded. This is a critical feature in Keras that needs to be addressed.
  2. Feature Extraction Robustness: The FeatureDistillation strategy uses a method for extracting intermediate layer features that is not robust and will fail for models with non-sequential architectures (e.g., ResNets). This limitation needs to be documented and ideally improved.
  3. Code Simplification: There's a small piece of unreachable code in the Distiller's loss computation that can be simplified.

Addressing these points will significantly improve the robustness and usability of this new API.

Comment on lines 283 to 287
if isinstance(y_pred, list) and len(y_pred) > 0:
    # For multi-output, use first output for student loss
    student_loss = self.student_loss_fn(y[0], y_pred[0])
else:
    student_loss = self.student_loss_fn(y, y_pred)
Contributor

Severity: medium

The isinstance(y_pred, list) check on line 283 is redundant because y_pred is converted to a list on line 263. This makes the else block on line 286 unreachable. The logic can be simplified to directly use the first output for the student loss calculation.

# Fallback: use student_loss_fn directly.
# For multi-output, use first output for student loss.
student_loss = self.student_loss_fn(y[0], y_pred[0])

@divyashreepathihalli removed the request for review from hertschuh on August 12, 2025, 00:11
@divyashreepathihalli (Collaborator, Author)

/gemini review

@gemini-code-assist (bot, Contributor) left a comment

Code Review

This pull request introduces a comprehensive and well-designed Knowledge Distillation API to Keras. The implementation is robust, featuring a flexible Distiller class and a set of pluggable distillation strategies that cover common use cases like logits and feature distillation, as well as multi-output models. The code is accompanied by extensive and thorough tests, which is excellent. My feedback includes a couple of suggestions to improve code style in the API files and to enhance the robustness of a test case by removing a broad exception handler. Overall, this is a high-quality contribution that will be a valuable addition to Keras.

@codecov-commenter commented on Aug 12, 2025

Codecov Report

❌ Patch coverage is 69.13183% with 96 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.53%. Comparing base (cc56474) to head (df07758).
⚠️ Report is 11 commits behind head on master.

Files with missing lines (patch coverage):
  • keras/src/distillation/distiller.py: 73.43% (30 missing, 21 partials) ⚠️
  • keras/src/distillation/distillation_loss.py: 62.38% (27 missing, 14 partials) ⚠️
  • keras/api/_tf_keras/keras/distillation/__init__.py: 0.00% (4 missing) ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21572      +/-   ##
==========================================
- Coverage   82.59%   82.53%   -0.06%     
==========================================
  Files         572      576       +4     
  Lines       58314    58846     +532     
  Branches     9130     9220      +90     
==========================================
+ Hits        48166    48571     +405     
- Misses       7817     7907      +90     
- Partials     2331     2368      +37     
Flag coverage (Δ):
  • keras: 82.34% <69.13%> (-0.06%) ⬇️
  • keras-jax: 63.24% <69.13%> (-0.06%) ⬇️
  • keras-numpy: 57.36% <19.29%> (-0.29%) ⬇️
  • keras-openvino: 34.25% <19.29%> (-0.08%) ⬇️
  • keras-tensorflow: 63.99% <69.13%> (-0.05%) ⬇️
  • keras-torch: 63.54% <69.13%> (-0.09%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@fchollet (Collaborator) left a comment

Thanks for the PR! Some quick comments on the API.

@hertschuh (Collaborator) left a comment

Thanks Divya!

Also, keras.tree is your friend.

if not isinstance(student_outputs, (list, tuple)):
    student_outputs = [student_outputs]

if len(teacher_outputs) != len(student_outputs):
@hertschuh (Collaborator) commented on Sep 5, 2025

Oh, we agreed to support any structure for outputs, including dicts and nested outputs.

So remove lines 50-53 and replace lines 55-61 with keras.tree.assert_same_structure(teacher_outputs, student_outputs). It's actually a lot simpler.
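
For reference, a hedged sketch of the keras.tree-based check (placeholder tensors; only the nesting structure is compared, not shapes):

import keras

teacher_outputs = {
    "logits": keras.ops.zeros((2, 10)),
    "features": keras.ops.zeros((2, 64)),
}
student_outputs = {
    "logits": keras.ops.zeros((2, 10)),
    "features": keras.ops.zeros((2, 32)),
}

# Raises an error if the two nested structures (dicts, lists, tuples,
# arbitrary nesting) do not match.
keras.tree.assert_same_structure(teacher_outputs, student_outputs)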

Collaborator

This wasn't changed.

Collaborator Author

done!

"""
if self._teacher_feature_extractor is not None:
# Use efficient multi-output extractor (returns dict directly)
return self._teacher_feature_extractor(x, training=False)
Collaborator

Hmm... so in the case when there are LogitsDistillation instances, that means that the student model is run twice. Once in self.call to get y_pred and once here to get the other logits.

Collaborator Author

Updated: y_pred is now reused to avoid recomputation.

@hertschuh (Collaborator) commented on Oct 1, 2025

What I meant is that y_pred is computed once in self.call and then another time when calling _student_feature_extractor. But let's not worry about that now.
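
A hedged sketch of the single-pass idea for later (hypothetical helper, functional/Sequential students only): include the model's final output among the feature extractor's outputs so one forward pass yields both the intermediate features and y_pred.

import keras

def build_student_extractor(student, feature_layer_names):
    # One forward pass returns intermediate features plus the final
    # predictions, so y_pred does not have to be recomputed separately.
    outputs = {
        name: student.get_layer(name).output for name in feature_layer_names
    }
    outputs["y_pred"] = student.outputs[0]
    return keras.Model(inputs=student.inputs, outputs=outputs)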

@hertschuh (Collaborator) left a comment

I'm a little confused about whether subclassed models are supported for feature distillation.

Comment on lines 367 to 369
except (ValueError, AttributeError):
    # Fallback for subclassed models
    return None
Collaborator

I think you should throw a meaningful error here.

If I read this correctly, when you try to use LogitsDistillation strategies on a subclass model (and all of KerasHub models are subclass models), this will return None, and then later you'll get the error on line 422, f"Layer '{layer_name}' features not found in extracted ", which will be confusing to users.

Instead you should say here (or somewhere else) that LogitsDistillation is not compatible with subclass models.

One alternative would be to check that in validate_model_compatibility. Right now, you validate that the layer name is found, but you could also validate that it's not a subclass model.

Collaborator Author

KerasHub backbones are all written in the functional-model style. Logits distillation would work for subclassed models, but feature distillation would not, because the model.inputs and model.outputs attributes would not be accessible.

Added checks in the Distiller and removed the silent fallback.
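
A hedged sketch of that kind of check (hypothetical helper; the PR's actual validation may differ): reject subclassed models with a clear error and build the extractor from the functional graph otherwise.

import keras

def build_feature_extractor(model, layer_name):
    # Only Functional/Sequential models expose symbolic `inputs`/`outputs`,
    # which are needed to slice out an intermediate layer's output.
    if not getattr(model, "inputs", None):
        raise ValueError(
            "FeatureDistillation requires a Functional or Sequential model; "
            f"got a subclassed model of type {type(model).__name__}, whose "
            "intermediate layer outputs are not accessible."
        )
    return keras.Model(
        inputs=model.inputs,
        outputs=model.get_layer(layer_name).output,
    )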

f"targeting teacher layer "
f"'{strategy.teacher_layer_name}' and student "
f"layer '{strategy.student_layer_name}'. This can "
f"happen with subclassed models that haven't "
Collaborator

Doesn't it happen with all subclass models always? Even when they're built? (based on line 369)

Based on my comment on line 369, can't we catch this error earlier rather than trying to guess that any ValueError happening here has this as the cause?

Comment on lines 596 to 597
# Ensure distillation_loss is a scalar
if len(distillation_loss.shape) > 0:
Collaborator

Should this be an error (checking strategy_loss within the for loop)? Aren't strategies supposed to output scalars?

Collaborator Author

Raising an error inside the loop if a strategy returns a non-scalar loss.
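
A hedged sketch of that per-strategy check (hypothetical helper name):

import keras

def check_scalar_strategy_loss(strategy_name, strategy_loss):
    # Every strategy is expected to return a scalar loss; anything with a
    # non-empty shape is rejected with a clear error.
    if len(getattr(strategy_loss, "shape", ())) > 0:
        raise ValueError(
            f"Distillation strategy '{strategy_name}' must return a scalar "
            f"loss, but got a tensor with shape {strategy_loss.shape}."
        )
    return strategy_loss

# Example: a reduced (scalar) loss passes the check.
check_scalar_strategy_loss("LogitsDistillation", keras.ops.mean(keras.ops.ones((4,))))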
